import torch
import torch.optim as optim
import torch.nn as nn
import argparse
import os
import itertools

from model import WideResNet
from data import get_cifar10_loaders, get_cifar100_loaders, get_cifar10_loaders_semi, get_cifar100_loaders_semi
from IAM import inconsistencyLoss, inconsistency_semi
from SAM import SAMLoss

def evaluate(model):
    """Return the top-1 accuracy of ``model``, in percent, over ``test_loader``.

    Reads the module-level ``test_loader`` and ``device`` names, which the
    script entry point defines before calling this function. The model is
    switched to eval mode and is not switched back afterwards.

    Args:
        model: Callable module mapping a batch of images to per-class scores.

    Returns:
        ``100 * correct / total`` as a float.
    """
    model.eval()
    num_correct, num_seen = 0, 0
    # Gradients are not needed for scoring, so skip building the graph.
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            # Index of the highest score along dim 1 is the predicted class.
            predictions = model(batch_images).max(1)[1]
            num_correct += predictions.eq(batch_labels).sum().item()
            num_seen += batch_labels.size(0)
    return 100. * num_correct / num_seen

def parse_bool_flag(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- becomes ``True``. This parser
    maps the usual spellings explicitly instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as ``"true"``,
            ``"0"`` or ``"yes"``. Matching is case-insensitive and the empty
            string counts as false.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def cycle_loader(loader):
    """Yield items from ``loader`` forever, re-iterating it on every pass.

    Nothing is cached: each pass calls ``iter(loader)`` again, so memory use
    does not grow with the number of batches and whatever the loader does on
    iteration (for example reshuffling) happens again on each pass.

    Args:
        loader: Any re-iterable object.

    Yields:
        The items of ``loader``, pass after pass. The generator ends if a
        pass produces no items, rather than spinning forever.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


if __name__ == "__main__":
    # NOTE(review): machine-specific temp directory; kept as-is. Confirm
    # whether this should be configurable before running elsewhere.
    os.environ["TMPDIR"] = "/home/intern/tmp"

    parser = argparse.ArgumentParser()
    # `choices` rejects typos up front instead of failing later in the script.
    parser.add_argument("--optimizer", default="IAM", type=str,
                        choices=["IAM", "SAM", "SGD"])
    parser.add_argument("--dropout", default=0.0, type=float)
    parser.add_argument("--ascent", default=0.1, type=float)
    parser.add_argument("--epochs", default=200, type=int)
    parser.add_argument("--lr", default=0.1, type=float)
    parser.add_argument("--beta", default=1.0, type=float)
    parser.add_argument("--dataset", default="CIFAR-10", type=str,
                        choices=["CIFAR-10", "CIFAR-100"])
    # Semi-supervised mode. Accepts `--semi`, `--semi True` and `--semi False`.
    parser.add_argument("--semi", default=False, type=parse_bool_flag,
                        nargs="?", const=True)
    args = parser.parse_args()

    print("===== TRAINING CONFIGURATION =====")
    for k, v in vars(args).items():
        print(f"{k:15s}: {v}")
    print("==================================\n")

    rho = args.ascent
    # CUDA
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize dataloaders. `test_loader` and `device` must stay module-level
    # names because evaluate() reads them as globals.
    num_labels = 10 if args.dataset == "CIFAR-10" else 100
    if args.semi:
        # Change val_split to modify missing label rate
        if args.dataset == "CIFAR-10":
            train_loader, val_loader, test_loader = get_cifar10_loaders_semi(val_split=0.8)
        else:
            train_loader, val_loader, test_loader = get_cifar100_loaders_semi(val_split=0.8)
        # Endless stream of batches for the second input of inconsistency_semi.
        val_loader = cycle_loader(val_loader)
    else:
        if args.dataset == "CIFAR-10":
            train_loader, test_loader = get_cifar10_loaders()
        else:
            train_loader, test_loader = get_cifar100_loaders()

    model = WideResNet(depth=16, width_factor=8, dropout=args.dropout, in_channels=3, labels=num_labels)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)

    noise_scale = 3.0
    # Guard the per-epoch averages against an empty training loader.
    num_batches = max(len(train_loader), 1)

    # Training Loop
    for epoch in range(args.epochs):

        total_loss = 0.0
        total_inconsistency = 0.0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # evaluate() leaves the model in eval mode at the end of every
            # epoch, so switch back for every optimizer, not only SGD and IAM.
            # Whether SAMLoss does this itself is not visible from this file.
            model.train()

            if args.optimizer == "SAM":
                loss = SAMLoss(model, images, labels, criterion, optimizer, args.ascent)

            elif args.optimizer == "SGD":
                outputs = model(images)
                loss = criterion(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            else:  # "IAM" (guaranteed by `choices`)
                if args.semi:
                    # Only the images of the second batch are used.
                    val_images, _ = next(val_loader)
                    val_images = val_images.to(device)
                    loss, inconsistency = inconsistency_semi(model, images, val_images, labels, criterion, beta=args.beta, rho=rho, noise_scale=noise_scale)
                else:
                    loss, inconsistency = inconsistencyLoss(model, images, labels, criterion, beta=args.beta, rho=rho, noise_scale=noise_scale)
                # The optimised (and logged) loss includes the inconsistency term.
                loss += inconsistency
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                optimizer.step()
                total_inconsistency += inconsistency.item()

            total_loss += loss.item()
        # The learning-rate schedule advances once per epoch.
        scheduler.step()

        # Average Loss
        avg_loss = total_loss / num_batches
        acc = evaluate(model)
        error = 100 - acc

        avg_inconsistency = total_inconsistency / num_batches

        print(
            f"Epoch: {epoch}\t"
            f"Loss: {avg_loss:.4f}\t"
            f"Test Error: {error:.2f}%\t"
            f"Inconsistency: {avg_inconsistency:.4f}"
        )
